# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# import os

# os.environ['FLAGS_enable_pir_api'] = '0'

# import tempfile
# import unittest

# import paddle
# from paddle.dataset.common import download
# from paddle.distributed.fleet.dataset import TreeIndex

# paddle.enable_static()


# def create_feeds():
#     user_input = paddle.static.data(
#         name="item_id", shape=[-1, 1], dtype="int64"
#     )

#     item = paddle.static.data(name="unit_id", shape=[-1, 1], dtype="int64")

#     label = paddle.static.data(name="label", shape=[-1, 1], dtype="int64")
#     labels = paddle.static.data(name="labels", shape=[-1, 1], dtype="int64")

#     feed_list = [user_input, item, label, labels]
#     return feed_list


# class TestTreeIndex(unittest.TestCase):
#     def test_tree_index(self):
#         path = download(
#             "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb",
#             "tree_index_unittest",
#             "e2ba4561c2e9432b532df40546390efa",
#         )
#         '''
#         path = download(
#             "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb",
#             "tree_index_unittest", "cadec20089f5a8a44d320e117d9f9f1a")
#         '''
#         tree = TreeIndex("demo", path)
#         height = tree.height()
#         branch = tree.branch()
#         self.assertTrue(height == 5)
#         self.assertTrue(branch == 2)
#         self.assertEqual(tree.total_node_nums(), 25)
#         self.assertEqual(tree.emb_size(), 30)

#         # get_layer_codes
#         layer_node_ids = []
#         layer_node_codes = []
#         for i in range(tree.height()):
#             layer_node_codes.append(tree.get_layer_codes(i))
#             layer_node_ids.append(
#                 [node.id() for node in tree.get_nodes(layer_node_codes[-1])]
#             )

#         all_leaf_ids = [node.id() for node in tree.get_all_leaves()]
#         self.assertEqual(sum(all_leaf_ids), sum(layer_node_ids[-1]))

#         # get_travel
#         travel_codes = tree.get_travel_codes(all_leaf_ids[0])
#         travel_ids = [node.id() for node in tree.get_nodes(travel_codes)]

#         for i in range(height):
#             self.assertIn(travel_ids[i], layer_node_ids[height - 1 - i])
#             self.assertIn(travel_codes[i], layer_node_codes[height - 1 - i])

#         # get_ancestor
#         ancestor_codes = tree.get_ancestor_codes([all_leaf_ids[0]], height - 2)
#         ancestor_ids = [node.id() for node in tree.get_nodes(ancestor_codes)]

#         self.assertEqual(ancestor_ids[0], travel_ids[1])
#         self.assertEqual(ancestor_codes[0], travel_codes[1])

#         # get_pi_relation
#         pi_relation = tree.get_pi_relation([all_leaf_ids[0]], height - 2)
#         self.assertEqual(pi_relation[all_leaf_ids[0]], ancestor_codes[0])

#         # get_travel_path
#         travel_path_codes = tree.get_travel_path(
#             travel_codes[0], travel_codes[-1]
#         )
#         travel_path_ids = [
#             node.id() for node in tree.get_nodes(travel_path_codes)
#         ]

#         self.assertEqual([*travel_path_ids, travel_ids[-1]], travel_ids)
#         self.assertEqual([*travel_path_codes, travel_codes[-1]], travel_codes)

#         # get_children
#         children_codes = tree.get_children_codes(travel_codes[1], height - 1)
#         children_ids = [node.id() for node in tree.get_nodes(children_codes)]
#         self.assertIn(all_leaf_ids[0], children_ids)


# class TestIndexSampler(unittest.TestCase):
#     def setUp(self):
#         self.temp_dir = tempfile.TemporaryDirectory()

#     def tearDown(self):
#         self.temp_dir.cleanup()

#     def test_layerwise_sampler(self):
#         path = download(
#             "https://paddlerec.bj.bcebos.com/tree-based/data/mini_tree.pb",
#             "tree_index_unittest",
#             "e2ba4561c2e9432b532df40546390efa",
#         )

#         tdm_layer_counts = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
#         # tree = TreeIndex("demo", path)
#         file_name = os.path.join(
#             self.temp_dir.name, "test_in_memory_dataset_tdm_sample_run.txt"
#         )
#         with open(file_name, "w") as f:
#             # data = "29 d 29 d 29 29 29 29 29 29 29 29 29 29 29 29\n"
#             data = "1 1 1 15 15 15\n"
#             data += "1 1 1 15 15 15\n"
#             f.write(data)

#         slots = ["slot1", "slot2", "slot3"]
#         slots_vars = []
#         for slot in slots:
#             var = paddle.static.data(name=slot, shape=[-1, 1], dtype="int64")
#             slots_vars.append(var)

#         dataset = paddle.distributed.InMemoryDataset()
#         dataset.init(
#             batch_size=1,
#             pipe_command="cat",
#             download_cmd="cat",
#             use_var=slots_vars,
#         )
#         dataset.set_filelist([file_name])
#         # dataset.update_settings(pipe_command="cat")
#         # dataset._init_distributed_settings(
#         #    parse_ins_id=True,
#         #    parse_content=True,
#         #    fea_eval=True,
#         #    candidate_size=10000)

#         dataset.load_into_memory()
#         dataset.tdm_sample(
#             'demo',
#             tree_path=path,
#             tdm_layer_counts=tdm_layer_counts,
#             start_sample_layer=1,
#             with_hierarchy=False,
#             seed=0,
#             id_slot=2,
#         )
#         self.assertTrue(dataset.get_shuffle_data_size() == 8)


# if __name__ == '__main__':
#     unittest.main()
